{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a word2vec model with Gensim\n",
    "\n",
    "In this example, we'll use Gensim to train a word2vec model on the text of the novel *War and Peace* by Leo Tolstoy. We'll use `nltk` to split the novel into sentences and Gensim's `simple_preprocess` to tokenize each sentence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import pprint  # beautify prints\n",
    "\n",
    "import gensim\n",
    "import nltk\n",
    "# nltk.download('punkt')  # Uncomment this to download the punkt tokenizer\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define and instantiate the `TokenizedSentences` class. It splits the text into sentences with the `nltk` sentence tokenizer and then tokenizes each sentence with `gensim.utils.simple_preprocess`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TokenizedSentences:\n",
    "    \"\"\"Split text into sentences and tokenize each sentence.\"\"\"\n",
    "\n",
    "    def __init__(self, filename: str):\n",
    "        self.filename = filename\n",
    "\n",
    "    def __iter__(self):\n",
    "        with open(self.filename) as f:\n",
    "            corpus = f.read()\n",
    "\n",
    "        raw_sentences = nltk.tokenize.sent_tokenize(corpus)\n",
    "        for sentence in raw_sentences:\n",
    "            if len(sentence) > 0:\n",
    "                yield gensim.utils.simple_preprocess(sentence, min_len=2, max_len=15)\n",
    "\n",
    "\n",
    "sentences = TokenizedSentences('war_and_peace.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll instantiate and train the word2vec model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:gensim.models.word2vec:collecting all words and their counts\n",
      "INFO:gensim.models.word2vec:PROGRESS: at sentence #0, processed 0 words, keeping 0 word types\n",
      "INFO:gensim.models.word2vec:PROGRESS: at sentence #10000, processed 154865 words, keeping 9561 word types\n",
      "INFO:gensim.models.word2vec:PROGRESS: at sentence #20000, processed 329560 words, keeping 13951 word types\n",
      "INFO:gensim.models.word2vec:PROGRESS: at sentence #30000, processed 507242 words, keeping 16748 word types\n",
      "INFO:gensim.models.word2vec:collected 17433 word types from a corpus of 551017 raw words and 32040 sentences\n",
      "INFO:gensim.models.word2vec:Loading a fresh vocabulary\n",
      "INFO:gensim.models.word2vec:effective_min_count=5 retains 6538 unique words (37% of original 17433, drops 10895)\n",
      "INFO:gensim.models.word2vec:effective_min_count=5 leaves 531623 word corpus (96% of original 551017, drops 19394)\n",
      "INFO:gensim.models.word2vec:deleting the raw counts dictionary of 17433 items\n",
      "INFO:gensim.models.word2vec:sample=0.001 downsamples 49 most-common words\n",
      "INFO:gensim.models.word2vec:downsampling leaves estimated 385800 word corpus (72.6% of prior 531623)\n",
      "INFO:gensim.models.base_any2vec:estimated required memory for 6538 words and 100 dimensions: 8499400 bytes\n",
      "INFO:gensim.models.word2vec:resetting layer weights\n",
      "INFO:gensim.models.base_any2vec:training model with 3 workers on 6538 vocabulary and 100 features, using sg=1 hs=0 sample=0.001 negative=5 window=5\n",
      "INFO:gensim.models.base_any2vec:EPOCH 1 - PROGRESS: at 1.95% examples, 4411 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:EPOCH 1 - PROGRESS: at 62.46% examples, 87680 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads\n",
      "INFO:gensim.models.base_any2vec:EPOCH - 1 : training on 551017 raw words (385768 effective words) took 3.3s, 117858 effective words/s\n",
      "INFO:gensim.models.base_any2vec:EPOCH 2 - PROGRESS: at 1.95% examples, 4477 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:EPOCH 2 - PROGRESS: at 68.19% examples, 98061 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads\n",
      "INFO:gensim.models.base_any2vec:EPOCH - 2 : training on 551017 raw words (385724 effective words) took 3.1s, 123661 effective words/s\n",
      "INFO:gensim.models.base_any2vec:EPOCH 3 - PROGRESS: at 1.95% examples, 4847 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:EPOCH 3 - PROGRESS: at 69.88% examples, 103019 words/s, in_qsize 1, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads\n",
      "INFO:gensim.models.base_any2vec:EPOCH - 3 : training on 551017 raw words (385237 effective words) took 3.0s, 128758 effective words/s\n",
      "INFO:gensim.models.base_any2vec:EPOCH 4 - PROGRESS: at 1.95% examples, 4723 words/s, in_qsize 1, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:EPOCH 4 - PROGRESS: at 73.02% examples, 108399 words/s, in_qsize 0, out_qsize 1\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads\n",
      "INFO:gensim.models.base_any2vec:EPOCH - 4 : training on 551017 raw words (385755 effective words) took 2.9s, 131192 effective words/s\n",
      "INFO:gensim.models.base_any2vec:EPOCH 5 - PROGRESS: at 1.95% examples, 4893 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:EPOCH 5 - PROGRESS: at 68.19% examples, 103292 words/s, in_qsize 0, out_qsize 0\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 2 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 1 more threads\n",
      "INFO:gensim.models.base_any2vec:worker thread finished; awaiting finish of 0 more threads\n",
      "INFO:gensim.models.base_any2vec:EPOCH - 5 : training on 551017 raw words (386083 effective words) took 3.0s, 126758 effective words/s\n",
      "INFO:gensim.models.base_any2vec:training on a 2755085 raw words (1928567 effective words) took 15.4s, 125289 effective words/s\n"
     ]
    }
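   ],
   "source": [
    "# The parameter names below follow the Gensim 3.x API used in this notebook;\n",
    "# Gensim 4.x renames `size` to `vector_size` and `iter` to `epochs`.\n",
    "model = gensim.models.word2vec.Word2Vec(\n",
    "    sentences=sentences,\n",
    "    sg=1,  # 0 for CBOW and 1 for skip-gram\n",
    "    size=100,  # dimensionality of the embedding vectors\n",
    "    window=5,  # maximum distance between a word and its context words\n",
    "    negative=5,  # number of negative samples per positive example\n",
    "    min_count=5,  # ignore words that occur fewer than 5 times\n",
    "    iter=5,  # number of training epochs\n",
    ")"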
   ]
  },
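  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what `sg=1` and `window` control, the sketch below lists the (center, context) pairs that skip-gram training would draw from one tokenized sentence. The `skipgram_pairs` helper and the example tokens are hypothetical and not part of Gensim; Gensim's actual training also subsamples frequent words, varies the effective window size, and draws negative samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint\n",
    "\n",
    "\n",
    "def skipgram_pairs(tokens, window=2):\n",
    "    # Hypothetical helper: pair each token with its neighbours\n",
    "    # that lie at most `window` positions away.\n",
    "    pairs = []\n",
    "    for i, center in enumerate(tokens):\n",
    "        start = max(0, i - window)\n",
    "        stop = min(len(tokens), i + window + 1)\n",
    "        for j in range(start, stop):\n",
    "            if j != i:  # a word is not its own context\n",
    "                pairs.append((center, tokens[j]))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "# Made-up tokens for illustration, not taken from the corpus file\n",
    "example_tokens = ['prince', 'andrew', 'rode', 'away']\n",
    "example_pairs = skipgram_pairs(example_tokens, window=1)\n",
    "pprint.pprint(example_pairs)"
   ]
  },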
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the result of the training, we'll query the model for the most similar words to `mother`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:gensim.models.keyedvectors:precomputing L2-norms of word weight vectors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Words most similar to 'mother':\n",
      "[('brother', 0.8962684869766235),\n",
      " ('daughter', 0.8957958221435547),\n",
      " ('father', 0.8954352140426636),\n",
      " ('sister', 0.8898148536682129),\n",
      " ('husband', 0.876844048500061)]\n"
     ]
    }
   ],
   "source": [
    "print(\"Words most similar to 'mother':\")\n",
    "pprint.pprint(model.wv.most_similar(positive='mother', topn=5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can do the same for a combination of words. Let's query with `woman` + `king` and see whether `queen` appears among the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Words most similar to 'woman' and 'king':\n",
      "[('admirable', 0.9163172245025635),\n",
      " ('heiress', 0.9129906296730042),\n",
      " ('queen', 0.9082918167114258),\n",
      " ('providence', 0.9049242734909058),\n",
      " ('creature', 0.903453528881073)]\n"
     ]
    }
   ],
   "source": [
    "print(\"Words most similar to 'woman' and 'king':\")\n",
    "pprint.pprint(model.wv.most_similar(positive=['woman', 'king'], topn=5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Indeed, `queen` is among the most similar words. However, other results, such as `creature`, are not relevant. Training the model on a larger corpus would likely give better results."
   ]
  }
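  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the scores returned by `most_similar` are cosine similarities. The sketch below uses made-up 3-dimensional vectors, not the trained embeddings, to show roughly how a query with several positive words can be scored: each query vector is normalized, the normalized vectors are averaged, and the remaining words are ranked by cosine similarity to that average. The toy vectors and the `unit` and `rank_by_cosine` helpers are hypothetical; this reflects our understanding of the approach, not Gensim's exact code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Made-up vectors for illustration only\n",
    "toy_vectors = {\n",
    "    'king': [0.9, 0.1, 0.3],\n",
    "    'woman': [0.1, 0.9, 0.3],\n",
    "    'queen': [0.6, 0.7, 0.3],\n",
    "    'horse': [0.1, 0.1, 0.9],\n",
    "}\n",
    "\n",
    "\n",
    "def unit(vec):\n",
    "    # Scale a vector to length 1\n",
    "    norm = math.sqrt(sum(x * x for x in vec))\n",
    "    return [x / norm for x in vec]\n",
    "\n",
    "\n",
    "def rank_by_cosine(vectors, positive, topn=2):\n",
    "    # Average the normalized query vectors, then renormalize the average\n",
    "    units = [unit(vectors[word]) for word in positive]\n",
    "    query = unit([sum(column) / len(units) for column in zip(*units)])\n",
    "    # Cosine similarity of every non-query word to the averaged query\n",
    "    scores = {\n",
    "        word: sum(a * b for a, b in zip(unit(vec), query))\n",
    "        for word, vec in vectors.items()\n",
    "        if word not in positive\n",
    "    }\n",
    "    return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:topn]\n",
    "\n",
    "\n",
    "toy_ranking = rank_by_cosine(toy_vectors, ['woman', 'king'])\n",
    "print(toy_ranking)"
   ]
  }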
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
